[Memory optm] loss using torch + compile #337
Conversation
return logprobs

# Convert to fp32 for numerical stability
scaled_logits_fp32 = scaled_logits.float()
Noob question: what's the dtype for scaled_logits?
.float() casts it to torch.float32.
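For context, a minimal sketch of the cast (the bf16 input dtype and shapes here are just assumptions for illustration):

```python
import torch

# Model logits often come out in a reduced-precision dtype such as bf16.
scaled_logits = torch.randn(2, 8, 128, dtype=torch.bfloat16)

# .float() is shorthand for .to(torch.float32) and is a no-op if already fp32.
scaled_logits_fp32 = scaled_logits.float()
print(scaled_logits_fp32.dtype)  # torch.float32
```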
@ebsmothers @Jack-Khuu @joecummings @pbontrager can some of you confirm that I don't need to do the all_gather that was happening in
import torch.nn.functional as F


def selective_log_softmax(logits: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
Should we also delete this function?
It's used in 3 other places, so I will leave it there for now. We will probably need some larger refactoring later to clean up / organize the losses.
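For reference, a minimal sketch of what a selective log-softmax with this signature typically computes (the body below is an assumption based on the signature, not necessarily this repo's implementation):

```python
import torch
import torch.nn.functional as F


def selective_log_softmax(logits: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    # Compute log-probabilities over the vocab dimension, then gather only the
    # entries at `index` (e.g. the token ids that were actually generated),
    # so callers never keep the full [batch, seq, vocab] log-prob tensor around.
    logprobs = F.log_softmax(logits, dim=-1)
    return torch.gather(logprobs, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
```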
Co-authored-by: Jiyue Wang <[email protected]>
# compile loss
logger.info("Compiling loss")
self.loss = torch.compile(self.loss)
Is there any circumstance under which this command would fail?
Can't think of one in our scenario, but if/when that happens, we can fix it.
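For reference, a minimal standalone sketch of wrapping a loss in `torch.compile` (the `loss_fn` and dummy tensors here are hypothetical stand-ins, not the repo's actual code):

```python
import torch
import torch.nn.functional as F


def loss_fn(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    return F.cross_entropy(logits, targets)


# torch.compile returns a wrapped callable; compilation is lazy and only runs
# on the first call with real tensors, so most failures would surface there
# (and could be worked around by simply dropping the wrapper).
compiled_loss = torch.compile(loss_fn)

logits = torch.randn(4, 10)
targets = torch.randint(0, 10, (4,))
loss = compiled_loss(logits, targets)
```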
logprobs = selective_log_softmax(scaled_logits, input_ids)
return logprobs

# Cast up to fp32 for numerical stability
nit: I would change this to something like "ensure logits are in fp32", because they could already be in fp32, in which case there is no "casting up".
Co-authored-by: Joe Cummings <[email protected]>
…into compile_loss
Memory freebies

I don't think that loss/reward is a good way to check correctness here, but I compared the functions locally and they produce the same output.
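For reference, one way such a local equivalence check could look (the function names are hypothetical; this is a sketch, not the actual comparison script):

```python
import torch


def compare_outputs(old_fn, new_fn, *inputs):
    # Run both implementations on the same inputs and assert they agree
    # within floating-point tolerance.
    out_old = old_fn(*inputs)
    out_new = new_fn(*inputs)
    torch.testing.assert_close(out_old, out_new, rtol=1e-4, atol=1e-5)
    return out_old, out_new
```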